import torch
import torchvision
from PIL import Image

from model import *

# Load a saved model, run it on a single image, and print the raw output plus
# its argmax along dim 1 (presumably the predicted class index).
image_path = "./test_images/dog.jpg"

# Force 3-channel RGB: PNGs may be RGBA/palette and grayscale JPEGs are mode
# "L", any of which would break the (1, 3, 32, 32) input shape used below.
# The context manager closes the underlying file handle once converted.
with Image.open(image_path) as img:
    image = img.convert("RGB")

# Resize to 32x32, then ToTensor -> float tensor of shape (C, H, W) in [0, 1].
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])

# The checkpoint appears to be a whole pickled model (the result is used
# directly as a module), so the class definitions pulled in by
# `from model import *` must be importable.
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
# machine and keeps the model on the same device as the CPU input tensor.
# NOTE: torch.load unpickles arbitrary objects -- only load trusted files.
# On PyTorch >= 2.6 torch.load defaults to weights_only=True; loading a full
# pickled model then additionally requires weights_only=False.
model = torch.load("my_model_10.pth", map_location="cpu")
model.eval()  # inference mode (affects dropout/batch-norm layers, if any)

image = transform(image)
print(image.shape)
image = image.unsqueeze(0)  # add batch dimension -> (1, 3, 32, 32)
print(image.shape)

with torch.no_grad():  # no autograd graph needed for inference
    output = model(image)

print(output)
print(output.argmax(1))
